import torch
import torchvision
from torchvision import transforms
from .randomaug import RandAugment
from .mixup import BatchMixup

# Public API of this module: only the CIFAR-10 loader factory is exported.
__all__ = ["cifar10"]


def cifar10(config):
    transform_train = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.Resize(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    mixup_fn = None
    for k, v in config.Aug.items():
        if v.name.lower() == "randomaugment":
            transform_train.transforms.insert(0, RandAugment(**v.kwargs))
        elif v.name.lower() == "mixup":
            mixup_fn = BatchMixup(alpha=v.kwargs.alpha, num_classes=10)
        else:
            raise NotImplementedError

    transform_test = transforms.Compose(
        [
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    trainset = torchvision.datasets.CIFAR10(
        root=config.train.path,
        train=True,
        download=False,
        transform=transform_train,
    )
    testset = torchvision.datasets.CIFAR10(
        root=config.test.path,
        train=False,
        download=False,
        transform=transform_test,
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=config.train.train_batch, shuffle=True, num_workers=8
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=config.test.test_batch, shuffle=False, num_workers=8
    )
    return trainloader, testloader, mixup_fn
